// @HEADER
//*********************************************************************//
//  SiYuan: A numerical PDE solver                                     //
//  Copyright (2022) YUAN Xi                                           //
//  This Software is released under the BSD 2-Clause license detailed  //
//  in the file "LICENSE" in the top-level SiYuan directory            //
//*********************************************************************//
// @HEADER

#ifndef _EIGENSOLVER_IMPL_HPP
#define _EIGENSOLVER_IMPL_HPP

#include "Teuchos_XMLParameterListHelpers.hpp"
#include "Teuchos_TestForException.hpp"
#include "Teuchos_ScalarTraits.hpp"

#include "Thyra_TpetraThyraWrappers.hpp"

#include "Tpetra_Map.hpp"
#include "Tpetra_CrsGraph.hpp"
#include "Tpetra_CrsMatrix.hpp"
#include "Tpetra_Import.hpp"
#include "Tpetra_Export.hpp"
#include "TpetraExt_MatrixMatrix.hpp"

#include "AnasaziConfigDefs.hpp"
#include "AnasaziBasicEigenproblem.hpp"
#include "AnasaziFactory.hpp"
#include "AnasaziBasicOutputManager.hpp"
#include "AnasaziTpetraAdapter.hpp"


namespace Anasazi {

/// \brief Construct the operator from a linear solver and a mass matrix.
///
/// \param solver       Amesos2 solver handle; must not be null.
/// \param massMtx      Mass matrix; must not be null.
/// \param useTranspose Stored and later consulted by apply() to choose
///                     between its two code paths.
///
/// \throws std::invalid_argument if \c solver or \c massMtx is null.
///
/// The flag is only recorded here: it is not forwarded to the solver or to
/// the mass matrix (the code that used to toggle their transpose state is
/// disabled).
template<class Scalar>
TpetraGenOp<Scalar>::TpetraGenOp (const Teuchos::RCP<Amesos2::Solver<CrsMatrix,MV>>& solver,
             const Teuchos::RCP<CrsMatrix>& massMtx,
             const bool useTranspose)
  : solver_ (solver),
    massMtx_ (massMtx),
    useTranspose_ (useTranspose)
{
  // The members are copies of the arguments, so checking them is the same
  // as checking the arguments themselves.
  const bool solverMissing = solver_.is_null ();
  if (solverMissing) {
    throw std::invalid_argument (
      "TpetraGenOp constructor: The 'solver' input argument is null.");
  }

  const bool massMissing = massMtx_.is_null ();
  if (massMissing) {
    throw std::invalid_argument (
      "TpetraGenOp constructor: The 'massMtx' input argument is null.");
  }
}

/// \brief Apply the operator to \c X and store the result in \c Y.
///
/// When the operator was built with useTranspose == false this computes
/// Y = K^{-1} * M * X: the mass matrix is applied to X and the result is used
/// as the right-hand side of a solve with \c solver_.
/// Otherwise the order is swapped: a solve with X as right-hand side is done
/// first and the mass matrix is applied to that solution.
///
/// \param X     [in]  Input multivector.
/// \param Y     [out] Result multivector; its previous content is discarded.
/// \param mode  Not used by this implementation.
/// \param alpha Not used by this implementation.
/// \param beta  Not used by this implementation.
///
/// \throws std::logic_error if \c massMtx_ or \c solver_ is null.
template<class Scalar>
void
TpetraGenOp<Scalar>::apply(const Tpetra::MultiVector<Scalar> &X, Tpetra::MultiVector<Scalar> &Y,
           Teuchos::ETransp mode, Scalar alpha,  Scalar beta ) const
{
  if (massMtx_.is_null ()) {
    throw std::logic_error ("TpetraGenOp::Apply: massMtx_ is null");
  }
  if (solver_.is_null ()) {
    throw std::logic_error ("TpetraGenOp::Apply: solver_ is null");
  }

  if (! useTranspose_) {
    // Storage for M*X
    MV MX(X.getMap(), X.getNumVectors() );

    // Apply M*X.
    // TODO(review): original author asks whether constrained entries of X
    // should be zeroed here.
    massMtx_->apply(X, MX);
    Y.putScalar(0.0);

    // Solve the linear system K*Y = MX  =>  Y = K^-1 * M * X
    solver_->solve(&Y,&MX);
  }
  else { // apply the transposed operator
    // Storage for the solution of the solve with right-hand side X
    MV ATX(X.getMap (), X.getNumVectors ());

    // Hand both vectors directly to solve() instead of registering them via
    // setB()/setX(): those calls would leave the solver holding pointers to
    // function-local objects after this method returns. solve() takes the
    // right-hand side as a pointer-to-const, so no const_cast is needed.
    // NOTE(review): nothing visible here asks the solver for a transposed
    // solve; presumably that is expected to be configured on the solver
    // elsewhere - confirm.
    solver_->solve(&ATX, &X);

    // Apply M*ATX
    massMtx_->apply(ATX, Y);
  }
}

} // namespace Anasazi



namespace SiYuan {

//---------------------------------------------
//
//  Following EigenSolver
//
//---------------------------------------------

// Nonmember constructors

/// \brief Nonmember factory: allocate an EigenSolver and wrap it in an RCP.
///
/// \param pl       Parameter list forwarded to the EigenSolver constructor.
/// \param model    Model evaluator forwarded to the EigenSolver constructor.
/// \param observer Observer forwarded to the EigenSolver constructor.
/// \return Owning RCP to the newly created solver.
template<class Scalar>
Teuchos::RCP<EigenSolver<Scalar> >
buildEigenSolver( const Teuchos::RCP<Teuchos::ParameterList>& pl,
  const Teuchos::RCP<Thyra::ModelEvaluator<Scalar> > &model,
  const std::shared_ptr< EigenObserver_WriteToExodus<Scalar> > &observer)
{
  typedef EigenSolver<Scalar> solver_type;

  Teuchos::RCP<solver_type> eigenSolver =
    Teuchos::rcp(new solver_type(pl, model, observer));
  return eigenSolver;
}


// Initializers/Accessors


/// \brief Construct the eigensolver wrapper and read its settings.
///
/// Caches the number of parameter and response vectors and the solution
/// space of \c model, then reads the solver options from \c pl. Every option
/// is read with a default, so missing entries are filled in by the list.
///
/// Recognised entries (default in parentheses):
///  - "Which" ("SM"): targeted eigenvalues (SM, LM, SR, LR, SI or LI)
///  - "Symmetric" (true)
///  - "Num Eigenvalues" (10)
///  - "Block Size" (5): must be positive
///  - "Maximum Iterations" (500)
///  - "Convergence Tolerance" (1.0e-8): the historical misspelling
///    "Convergece Tolerance" is still honoured when the correctly spelled
///    entry is absent
///  - "Method" ("Block Davidson")
///  - "Sigma" (0.0)
///  - "Num Restart Blocks" (2 * nev / blockSize)
///  - "K pivot" (100.0)
///  - "Verbose" (false), "Debug" (false): "Debug" implies "Verbose"
///
/// \param pl       Parameter list holding the options above.
/// \param model    Model evaluator that is wrapped.
/// \param observer Observer stored for later use.
///
/// \throws Teuchos::Exceptions::InvalidParameter if "Block Size" is not
///         positive.
template<class Scalar>
EigenSolver<Scalar>::EigenSolver( const Teuchos::RCP<Teuchos::ParameterList>& pl,
  const Teuchos::RCP<Thyra::ModelEvaluator<Scalar> > &model,
  const std::shared_ptr< EigenObserver_WriteToExodus<Scalar> >& observer
  ) : model_(model), observer_(observer)
{
  MEB::InArgsSetup<Scalar> inArgs = model->createInArgs();
  MEB::OutArgsSetup<Scalar> outArgs = model->createOutArgs();
  num_p_ = inArgs.Np();   // Number of *vectors* of parameters
  num_g_ = outArgs.Ng();  // Number of *vectors* of responses

  x_space_ = model->get_x_space();

  // Targeted eigenvalues: Smallest Magnitude (SM), Largest Magnitude (LM),
  // Smallest Real (SR), Largest Real (LR), Smallest Imaginary (SI),
  // Largest Imaginary (LI)
  which = pl->get<std::string>("Which", "SM");
  bHermitian = pl->get<bool>("Symmetric",true);
  nev = pl->get<int>("Num Eigenvalues",10);
  blockSize = pl->get<int>("Block Size",5);
  // blockSize is used as a divisor below; reject values that would make the
  // integer division undefined or meaningless.
  TEUCHOS_TEST_FOR_EXCEPTION(blockSize <= 0, Teuchos::Exceptions::InvalidParameter,
     "EigenSolver: \"Block Size\" must be positive, but " << blockSize << " was given.");
  maxIters = pl->get<int>("Maximum Iterations",500);

  // Prefer the correctly spelled key. Fall back to the legacy misspelled key
  // (and its default) so existing input files keep working unchanged.
  if (pl->isParameter("Convergence Tolerance")) {
    conv_tol = pl->get<double>("Convergence Tolerance");
  }
  else {
    conv_tol = pl->get<double>("Convergece Tolerance",1.0e-8);
  }

  method = pl->get<std::string>("Method", "Block Davidson");
  sigma = pl->get<double>("Sigma",0.0);
  int numRestartBlocks = 2*nev/blockSize;
  nrestart = pl->get<int>("Num Restart Blocks",numRestartBlocks);
  K_pivot = pl->get<double>("K pivot",100.0);

  verbose = pl->get<bool>("Verbose", false);
  debug = pl->get<bool>("Debug", false);
  if( debug ) verbose = true;
}


// Public functions overridden from ModelEvaluator

/// \brief Return the space of parameter vector \c l, as reported by the
/// wrapped model.
template<class Scalar>
Teuchos::RCP<const Thyra::VectorSpaceBase<Scalar> >
EigenSolver<Scalar>::get_p_space(int l) const
{
  Teuchos::RCP<const Thyra::VectorSpaceBase<Scalar> > paramSpace =
    model_->get_p_space(l);
  return paramSpace;
}


/// \brief Return the names of the entries of parameter vector \c j, as
/// reported by the wrapped model.
template<class Scalar>
Teuchos::RCP<const Teuchos::Array<std::string> >
EigenSolver<Scalar>::get_p_names(int j) const
{
  Teuchos::RCP<const Teuchos::Array<std::string> > paramNames =
    model_->get_p_names(j);
  return paramNames;
}

/// \brief Return the space of response vector \c j.
///
/// Indices other than \c num_g_ are delegated to the wrapped model. Index
/// \c num_g_ is the additional, last response of this solver: the solution
/// vector, which lives in the model's x space.
template<class Scalar>
Teuchos::RCP<const Thyra::VectorSpaceBase<Scalar> >
EigenSolver<Scalar>::get_g_space(int j) const
{
  if (j != num_g_) {
    return model_->get_g_space(j);
  }
  // Last response vector is the solution (same map as x).
  return model_->get_x_space();
}

/// \brief Return the names of the entries of response vector \c l, as
/// reported by the wrapped model.
template<class Scalar>
Teuchos::ArrayView<const std::string>
EigenSolver<Scalar>::get_g_names(int l) const
{
  Teuchos::ArrayView<const std::string> responseNames = model_->get_g_names(l);
  return responseNames;
}


/// \brief Create the input-argument object of this evaluator.
///
/// The returned object carries this evaluator's description and is sized for
/// \c num_p_ parameter vectors; no other input is configured here.
template<class Scalar>
Thyra::ModelEvaluatorBase::InArgs<Scalar>
EigenSolver<Scalar>::createInArgs() const
{
  MEB::InArgsSetup<Scalar> solverInArgs;

  const std::string selfDescription = this->description();
  solverInArgs.setModelEvalDescription(selfDescription);

  solverInArgs.set_Np(num_p_);

  return solverInArgs;
}


// Private functions overridden from ModelEvaluatorDefaultBase

/// \brief Create the output-argument object of this evaluator.
///
/// The object is sized for \c num_p_ parameter vectors and \c num_g_ + 1
/// response vectors: one more than the wrapped model, so that the solution
/// vector can be returned as the last response. DgDp support for the model's
/// own responses is copied from the wrapped model.
template<class Scalar>
Thyra::ModelEvaluatorBase::OutArgs<Scalar>
EigenSolver<Scalar>::createOutArgsImpl() const
{
  MEB::OutArgsSetup<Scalar> solverOutArgs;

  const std::string selfDescription = this->description();
  solverOutArgs.setModelEvalDescription(selfDescription);

  // Ng is 1 bigger than the model's Ng so that the solution vector can be an
  // outarg.
  const int numResponsesWithSolution = num_g_+1;
  solverOutArgs.set_Np_Ng(num_p_, numResponsesWithSolution);

  // Derivative info: mirror what the wrapped model supports for each
  // (response, parameter) pair.
  MEB::OutArgsSetup<Scalar> wrappedOutArgs = model_->createOutArgs();
  for (int g = 0; g < num_g_; ++g) {
    for (int p = 0; p < num_p_; ++p) {
      solverOutArgs.setSupports(MEB::OUT_ARG_DgDp, g, p,
                                wrappedOutArgs.supports(MEB::OUT_ARG_DgDp, g, p));
    }
  }

  return solverOutArgs;
}


/// \brief Assemble K and M from the wrapped model, solve the generalized
/// eigenproblem K x = lambda M x with Anasazi and report the result.
///
/// Steps:
///  1. Pass \c K_pivot to the wrapped model (which must be a
///     panzer::ModelEvaluator<Scalar>).
///  2. Evaluate the model's W operator twice at x = 0, x_dot = 0, t = 0:
///     with (alpha, beta) = (0, 1) to obtain K and with (1, 0) to obtain M.
///  3. Build an Anasazi::BasicEigenproblem from K, M and a random initial
///     block, and solve it with the solver manager named by \c method.
///  4. Compute the residuals K*v - lambda*M*v, print a table to std::cout
///     and hand the solution to the observer (if one was supplied).
///
/// \param inArgs  Only the parameter vectors p(0) .. p(num_p_-1) are read;
///                they are forwarded to the wrapped model.
/// \param outArgs Not accessed by this implementation.
///
/// \throws std::logic_error if the wrapped model is not a
///         panzer::ModelEvaluator<Scalar>, or if one of the assembled
///         operators is not a Tpetra::CrsMatrix<Scalar>.
/// \throws Teuchos::Exceptions::InvalidParameter if
///         BasicEigenproblem::setProblem() fails or the Anasazi solver does
///         not return Anasazi::Converged.
template<class Scalar>
void EigenSolver<Scalar>::evalModelImpl(
  const Thyra::ModelEvaluatorBase::InArgs<Scalar> &inArgs,
  const Thyra::ModelEvaluatorBase::OutArgs<Scalar> &outArgs
  ) const
{
  // type definitions
  typedef Tpetra::CrsMatrix<Scalar>                     CrsMatrix;
  typedef Thyra::TpetraOperatorVectorExtraction<Scalar> TOE;
  typedef Tpetra::MultiVector<Scalar>                   TMV;
  typedef Tpetra::Operator<Scalar>                      TOP;
//  typedef Anasazi::TpetraGenOp<Scalar>                  TOP;
  typedef Anasazi::MultiVecTraits<Scalar,TMV>           TMVT;
  typedef Anasazi::OperatorTraits<Scalar,TMV,Tpetra::Operator<Scalar>>       TOPT;
  typedef Teuchos::ScalarTraits<Scalar>                 SCT;

  // The pivot can only be set on a panzer model evaluator; fail with a clear
  // message instead of dereferencing a null RCP when the cast fails.
  const Teuchos::RCP< panzer::ModelEvaluator<Scalar> > panzerModel =
    Teuchos::rcp_dynamic_cast< panzer::ModelEvaluator<Scalar> >(model_);
  TEUCHOS_TEST_FOR_EXCEPTION(panzerModel.is_null(), std::logic_error,
     "EigenSolver::evalModelImpl: the wrapped model is not a panzer::ModelEvaluator.");
  panzerModel->setKPivot(K_pivot);

  // Get the stiffness and mass matrices
  MEB::InArgsSetup<Scalar> model_inArgs = model_->createInArgs();
  MEB::OutArgsSetup<Scalar> model_outArgs = model_->createOutArgs();

  //input args
  model_inArgs.set_t(0.0);

//  Teuchos::RCP<const Thyra::VectorBase<Scalar> > x = model_inArgs.get_x();
  // Both matrices are evaluated at a zero state and zero state derivative.
  Teuchos::RCP<Thyra::VectorBase<Scalar> > x = Thyra::createMember(x_space_);
  Thyra::assign(x.ptr(),0.0);
  model_inArgs.set_x(x);
  Teuchos::RCP<Thyra::VectorBase<Scalar> > x_dot = Thyra::createMember(x_space_);
  Thyra::assign(x_dot.ptr(),0.0);
  model_inArgs.set_x_dot(x_dot);

  // alpha = 0, beta = 1 selects the stiffness contribution of W
  model_inArgs.set_alpha(0.0);
  model_inArgs.set_beta(1.0);

//  Teuchos::RCP<const S_Pivot> pivot = Teuchos::rcp(new S_Pivot(K_pivot));
//  model_inArgs.set(pivot);
//  model_inArgs.template setSupports<S_Pivot>();
 // model_inArgs.set_pivot_dirichlet(K_pivot);

  for(int i=0; i<num_p_; i++)
    model_inArgs.set_p(i, inArgs.get_p(i));

  //output args
  Teuchos::RCP< Thyra::LinearOpBase<Scalar> > K_op = model_->create_W_op();
  model_outArgs.set_W_op(K_op); 

  model_->evalModel(model_inArgs, model_outArgs); //compute K matrix

  // reset alpha and beta to compute the mass matrix
  model_inArgs.set_alpha(1.0);
  model_inArgs.set_beta(0.0);
 // model_inArgs.set_pivot_dirichlet(1.0);
  Teuchos::RCP< Thyra::LinearOpBase<Scalar> > M_op = model_->create_W_op();
  model_outArgs.set_W_op(M_op);

  model_->evalModel(model_inArgs, model_outArgs); //compute M matrix

  Teuchos::RCP< const CrsMatrix > K_tpetra = 
    Teuchos::rcp_dynamic_cast<CrsMatrix>( TOE::getTpetraOperator(K_op) );
  Teuchos::RCP< CrsMatrix > M_tpetra =
    Teuchos::rcp_dynamic_cast<CrsMatrix>( TOE::getTpetraOperator(M_op) );
  // Everything below dereferences both matrices, so a failed cast must be
  // reported here rather than surfacing as a null dereference later.
  TEUCHOS_TEST_FOR_EXCEPTION(K_tpetra.is_null(), std::logic_error,
     "EigenSolver::evalModelImpl: the stiffness operator is not a Tpetra::CrsMatrix.");
  TEUCHOS_TEST_FOR_EXCEPTION(M_tpetra.is_null(), std::logic_error,
     "EigenSolver::evalModelImpl: the mass operator is not a Tpetra::CrsMatrix.");

//  Tpetra::MatrixMarket::Writer<Tpetra::CrsMatrix<Scalar>>::writeSparseFile("K1.mm", *K_tpetra);
//  Tpetra::MatrixMarket::Writer<Tpetra::CrsMatrix<Scalar>>::writeSparseFile("M1.mm", *M_tpetra);

  // Disabled: spectral shift of K by sigma*M ("sigma" is currently read from
  // the parameter list but not applied).
 // if( sigma!=0.0 ) {
  //  K_tpetra->resumeFill();
  //  Tpetra::MatrixMatrix::Add<Scalar>( *M_tpetra, false, -sigma, *K_tpetra, 1.0);
  //  K_tpetra->fillComplete();
  //  double a =1.0;
  //  Teuchos::RCP< CrsMatrix > KM = Tpetra::MatrixMatrix::add( sigma, false, *M_tpetra, a, false, *K_tpetra);
  //  K_tpetra = KM;
 // }

//  Teuchos::RCP<Teuchos::FancyOStream> pout = Teuchos::getFancyOStream(Teuchos::rcpFromRef(std::cout));
//  K_tpetra->describe(*pout, Teuchos::VERB_EXTREME);
//  M_tpetra->describe(*pout, Teuchos::VERB_EXTREME);

  // Random initial block with blockSize columns
  Teuchos::RCP<TMV> ivec = Teuchos::rcp( new TMV(K_tpetra->getRowMap(), blockSize) );
  ivec->randomize();

  // Disabled: alternative formulation through TpetraGenOp (K^-1 * M) with an
  // Amesos2 direct solver.
  /*Teuchos::RCP<Amesos2::Solver<CrsMatrix,TMV>> AmesosSolver = Amesos2::create<CrsMatrix,TMV>("Klu", K_tpetra);
  if (AmesosSolver.is_null ()) {
    throw std::runtime_error ("Amesos appears not to have any solvers enabled.");
  }
  // The TpetraGenOp class assumes that the symbolic and numeric
  // factorizations have already been performed on the linear problem.
  AmesosSolver->symbolicFactorization().numericFactorization();

  Teuchos::RCP<TOP> Aop = Teuchos::rcp( new TOP(AmesosSolver, M_tpetra) );
  Teuchos::RCP<Anasazi::BasicEigenproblem<Scalar, TMV, TOP> > eigenProblem =
    Teuchos::rcp( new Anasazi::BasicEigenproblem<Scalar, TMV, TOP>(Aop, ivec) );*/

  // Create the eigenproblem.
  Teuchos::RCP<Anasazi::BasicEigenproblem<Scalar, TMV, TOP> > eigenProblem =
    Teuchos::rcp( new Anasazi::BasicEigenproblem<Scalar, TMV, TOP>(K_tpetra, M_tpetra, ivec) );

  // Inform the eigenproblem whether the operator is symmetric
  eigenProblem->setHermitian(bHermitian);

  // Set the number of eigenvalues requested
  eigenProblem->setNEV( nev );

  // Inform the eigenproblem that you are finishing passing it information
  bool bSuccess = eigenProblem->setProblem();
  TEUCHOS_TEST_FOR_EXCEPTION(!bSuccess, Teuchos::Exceptions::InvalidParameter,
     "Anasazi::BasicEigenproblem::setProblem() returned an error.\n" << std::endl);

  // Set verbosity level
  int verbosity = Anasazi::Errors + Anasazi::Warnings + Anasazi::FinalSummary + Anasazi::TimingDetails;
  if (verbose) {
    verbosity += Anasazi::IterationDetails;
  }
  if (debug) {
    verbosity += Anasazi::Debug;
  }

  // Create parameter list to pass into the Anasazi solver manager
  //
  Teuchos::ParameterList eigenPL;
  eigenPL.set( "Which", which );
  eigenPL.set( "Block Size", blockSize );
  eigenPL.set( "Maximum Iterations", maxIters );
  eigenPL.set( "Convergence Tolerance", conv_tol );
  eigenPL.set( "Full Ortho", true );
  eigenPL.set( "Use Locking", true );
  eigenPL.set( "Verbosity", verbosity );
  eigenPL.set( "Num Restart Blocks", nrestart );

  // Create the solver manager
  Teuchos::RCP<Anasazi::SolverManager<Scalar, TMV, TOP> > MyEigenSolver
    = Anasazi::Factory::create( method,eigenProblem,eigenPL );
//  Anasazi::BlockDavidsonSolMgr<Scalar, TMV, TOP> eigenSolverMan; 

   // Solve the problem
  Anasazi::ReturnType returnCode = MyEigenSolver->solve();
  TEUCHOS_TEST_FOR_EXCEPTION(returnCode != Anasazi::Converged, Teuchos::Exceptions::InvalidParameter,
    "Anasazi solver not converged.\n" << std::endl);

  // Get the eigenvalues and eigenvectors from the eigenproblem
  Anasazi::Eigensolution<Scalar,TMV> sol = eigenProblem->getSolution();
  std::vector<Anasazi::Value<Scalar> > evals = sol.Evals;
  Teuchos::RCP<TMV> evecs = sol.Evecs;

  // Only the real parts of the eigenvalues are used below.
  std::vector<double> evals_real(sol.numVecs);
  for(int i=0; i<sol.numVecs; i++) evals_real[i] = evals[i].realpart;

  // Compute residuals: columns of K*V - M*V*T with T = diag(eigenvalues).
  std::vector<double> normR(sol.numVecs);
  if (sol.numVecs > 0) {
    Teuchos::SerialDenseMatrix<int,double> T(sol.numVecs, sol.numVecs);
    TMV Kvec( K_tpetra->getRowMap(), evecs->getNumVectors() );
    TMV Mvec( M_tpetra->getRowMap(), evecs->getNumVectors() );
    T.putScalar(SCT::zero()); 
    for (int i=0; i<sol.numVecs; i++) {
      T(i,i) = evals_real[i];
    }
    TOPT::Apply(*K_tpetra, *evecs, Kvec );
    TOPT::Apply(*M_tpetra, *evecs, Mvec );
    TMVT::MvTimesMatAddMv( -SCT::one(), Mvec, T, SCT::one(), Kvec );
    TMVT::MvNorm( Kvec, normR );
  }

  // Print the results
  std::ostringstream os;
  os.setf(std::ios_base::right, std::ios_base::adjustfield);
  os<<"Solver manager returned " << (returnCode == Anasazi::Converged ? "converged." : "unconverged.") << std::endl;
  os<<std::endl;
  os<<"------------------------------------------------------"<<std::endl;
  os<<std::setw(16)<<"Eigenvalue"
    <<std::setw(18)<<"Direct Residual"
    <<std::endl;
  os<<"------------------------------------------------------"<<std::endl;
  for (int i=0; i<sol.numVecs; i++) {
    // Scale the residual by the magnitude of the eigenvalue so that the
    // printed value is never negative. For a zero eigenvalue the unscaled
    // residual is printed instead of inf/NaN.
    const double lambdaMagnitude = Teuchos::ScalarTraits<double>::magnitude(evals_real[i]);
    const double scaledResidual = (lambdaMagnitude > 0.0) ? normR[i]/lambdaMagnitude : normR[i];
    os<<std::setw(16)<<evals_real[i]
      <<std::setw(18)<<scaledResidual
      <<std::endl;
  }
  os<<"------------------------------------------------------"<<std::endl;

  std::cout << Anasazi::Anasazi_Version() << std::endl << std::endl;
  std::cout << os.str();

  // The observer is optional: skip the notification when none was supplied.
  if (observer_) {
    observer_ -> observeSolution( sol );
  }
}

} // namespace SiYuan

#endif
